#!/usr/bin/env python3
import json
import multiprocessing
import os
from collections import Counter

import jsonlines


class AGTR:
    """
    Class for computing bounds on precision, recall, accuracy, and error rate
    using an AGTR.

    Attributes:
    C_rev -- Predicted labels as a reverse mapping {id_: label}
    R_hat_rev -- AGTR labels as a reverse mapping {id_: label}
    C -- Predicted clustering of the form {label: {ids}}
    R_hat -- AGTR clustering of the form {label: {ids}}
    e_hat_m -- Value of e hat / m, stored as passed to the constructor
    m -- Number of distinct IDs
    """

    def __init__(self, pred_jsonl, agtr_jsonl, e_hat_m,
                 pred_format=("id", "label"), agtr_format=("id", "label")):
        """
        pred_jsonl -- .jsonl file containing predicted labels
        agtr_jsonl -- .jsonl file containing AGTR labels
        e_hat_m -- Value of e hat / m
        pred_format -- Variable names in predicted .jsonl file
        agtr_format -- Variable names in AGTR .jsonl file

        Raises ValueError if either file does not exist, if a record is
        missing one of the expected variable names, or if the two files do
        not contain exactly the same set of IDs.
        """

        self.C_rev = self.__get_clusters_rev(pred_jsonl, pred_format)
        self.R_hat_rev = self.__get_clusters_rev(agtr_jsonl, agtr_format)
        # Dict key views compare like sets, so no copies are needed here.
        if self.C_rev.keys() != self.R_hat_rev.keys():
            raise ValueError("IDs in pred_jsonl and agtr_jsonl do not match!")
        self.C = self.__get_clusters(self.C_rev)
        self.R_hat = self.__get_clusters(self.R_hat_rev)
        self.e_hat_m = e_hat_m
        self.m = len(self.C_rev)


    def __get_clusters_rev(self, jsonl_file, jsonl_format):
        """Reads a .jsonl file with the format:

        {id_name: id_, label_name: label}
        {id_name: id_, label_name: label}
        {id_name: id_, label_name: label}
        ...

        Returns a reverse cluster mapping of the form {id_: label}.

        If an ID occurs on more than one line, the last occurrence wins.

        Arguments:
        jsonl_file -- The path to the .jsonl file
        jsonl_format -- A tuple containing (id_name, label_name)

        Raises ValueError if the file does not exist or if a record lacks
        id_name or label_name.
        """

        if not os.path.isfile(jsonl_file):
            raise ValueError("No such file: {}".format(jsonl_file))

        (id_name, label_name) = jsonl_format

        # Parse id label pairs from jsonl
        cluster_rev = {}
        with jsonlines.open(jsonl_file) as reader:
            for obj in reader.iter(type=dict, skip_empty=True):
                if id_name not in obj:
                    raise ValueError("JSON missing {}".format(id_name))
                if label_name not in obj:
                    raise ValueError("JSON missing {}".format(label_name))
                id_ = obj[id_name]
                label = obj[label_name]
                # If label is None, assign to a singleton cluster.
                # NOTE: the ID itself is reused as the label, so a record
                # whose explicit label equals this ID would end up in the
                # same cluster.
                if label is None:
                    label = id_
                cluster_rev[id_] = label

        return cluster_rev


    def __get_clusters(self, cluster_rev):
        """Returns a clustering of the form {label: {ids}}.

        Arguments:
        cluster_rev -- A reverse cluster mapping of the form {id_: label}
        """

        clusters = {}
        for id_, label in cluster_rev.items():
            clusters.setdefault(label, set()).add(id_)

        return clusters


    @staticmethod
    def __sum_max_overlaps(clusters, other_rev):
        """Returns the sum, over every cluster in clusters, of the size of
        that cluster's largest overlap with a single cluster of the other
        clustering.

        Arguments:
        clusters -- A clustering of the form {label: {ids}}
        other_rev -- The other clustering as a mapping {id_: label}
        """

        total = 0
        for ids in clusters.values():
            # Counting the other clustering's labels among this cluster's
            # members yields every overlap size in a single pass, instead of
            # one full set intersection per member.
            overlaps = Counter(other_rev[id_] for id_ in ids)
            total += max(overlaps.values())

        return total


    def get_precision_bound(self):
        """Returns the precision lower bound.

        For each predicted cluster, the size of its largest overlap with a
        single AGTR cluster is taken; the sum of these is divided by m and
        e_hat_m is subtracted. The result is clamped to be at least 0.0.

        Raises ZeroDivisionError if there are no IDs (m == 0).
        """

        overlap = self.__sum_max_overlaps(self.C, self.R_hat_rev)
        bound = overlap / self.m - self.e_hat_m

        return max(bound, 0.0)


    def get_recall_bound(self):
        """Returns the recall upper bound.

        For each AGTR cluster, the size of its largest overlap with a single
        predicted cluster is taken; the sum of these is divided by m and
        e_hat_m is added. The result is clamped to be at most 1.0.

        Raises ZeroDivisionError if there are no IDs (m == 0).
        """

        overlap = self.__sum_max_overlaps(self.R_hat, self.C_rev)
        bound = overlap / self.m + self.e_hat_m

        return min(bound, 1.0)


    def get_accuracy_bound(self):
        """Returns the accuracy upper bound (same value as the recall
        upper bound)."""

        return self.get_recall_bound()


    def get_error_rate_bound(self):
        """Returns the error rate lower bound (1 minus the recall upper
        bound)."""

        return 1 - self.get_recall_bound()
